import torch
import torchaudio
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import pandas as pd
from torchmetrics.classification import MulticlassCohenKappa
from torchmetrics import functional as FM
from datasets.Cremad_dataset import CREMADDataset
import json
import logging
from pathlib import Path
import random
import numpy as np
import matplotlib.pyplot as plt
import os
from .resnet import resnet18
import warnings

# Process-wide runtime configuration applied on import:
#   * set TOKENIZERS_PARALLELISM to "false" in the environment,
#   * install a catch-all "ignore" warnings filter,
#   * raise the root logger threshold to WARNING.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
warnings.simplefilter("ignore")
logging.root.setLevel(logging.WARNING)

class AudioAndVisionModel(torch.nn.Module):
    """Concatenation-fusion classifier over an audio and a video backbone.

    Each backbone's output is globally average-pooled to one vector per
    sample; the two vectors are concatenated and passed through a single
    linear layer that produces ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        loss_fn: Callable ``loss_fn(logits, label)`` used when a label is
            supplied to :meth:`forward`.
        audio_module: Callable applied to the audio input.
        video_module: Callable applied to the (permuted) video input.
        audio_feature_dim: Width of the pooled audio vector.
        video_feature_dim: Width of the pooled video vector.
        fusion_output_size: Accepted for interface compatibility; not used.
        use_augmentation: Stored on the instance; not used by ``forward``.
        train_dataset: Stored on the instance; not used by ``forward``.
        batch_size: Accepted for interface compatibility; not used.
        device_idx: Accepted for interface compatibility; not used.
    """

    def __init__(
        self,
        num_classes,
        loss_fn,
        audio_module,
        video_module,
        audio_feature_dim,
        video_feature_dim,
        fusion_output_size,
        use_augmentation,
        train_dataset,
        batch_size = 32,
        device_idx=None,
    ):
        super().__init__()
        self.audio_module = audio_module
        self.video_module = video_module
        self.train_dataset = train_dataset

        # The classification head sees both pooled vectors side by side.
        fused_width = audio_feature_dim + video_feature_dim
        classifier = torch.nn.Linear(fused_width, num_classes)
        self.task_net = torch.nn.Sequential(classifier)

        self.num_classes = num_classes
        self.loss_fn = loss_fn
        self.loss_fn_none = torch.nn.CrossEntropyLoss(reduction="none")
        self.use_augmentation = use_augmentation

    def forward(self, audio, video, label=None, use_augmentation=False):
        """Compute logits and, when ``label`` is given, the loss.

        Args:
            audio: Input forwarded unchanged to ``audio_module``.
            video: 5-D tensor; axes 1 and 2 are swapped before it is handed
                to ``video_module``.
                # assumes (batch, frames, channels, H, W) -- TODO confirm
            label: Optional targets passed to ``loss_fn``.
            use_augmentation: Accepted but ignored.

        Returns:
            ``(logits, loss)`` where ``loss`` is ``None`` if ``label`` is
            ``None``.
        """
        clip = video.permute(0, 2, 1, 3, 4).contiguous()
        audio_map = self.audio_module(audio)
        frame_maps = self.video_module(clip)

        # The video backbone returns a 4-D tensor whose leading axis is
        # regrouped per sample, using the audio batch size as reference.
        batch = audio_map.size(0)
        _, channels, height, width = frame_maps.size()
        clip_map = frame_maps.view(batch, -1, channels, height, width)
        clip_map = clip_map.permute(0, 2, 1, 3, 4)

        # Global average pooling followed by flattening to (batch, features).
        audio_vec = F.adaptive_avg_pool2d(audio_map, 1).flatten(1)
        video_vec = F.adaptive_avg_pool3d(clip_map, 1).flatten(1)

        logits = self.task_net(torch.cat((audio_vec, video_vec), dim=1))

        if label is None:
            return (logits, None)
        return (logits, self.loss_fn(logits, label))

    
class BaselineModel(pl.LightningModule):
    """Lightning module that trains an ``AudioAndVisionModel`` with manual optimization.

    The module owns its datasets, dataloaders, model and ``Trainer``
    configuration, all derived from a single ``params`` mapping.

    Required keys in ``params``: ``train_path``, ``dev_path``, ``audio_dir``,
    ``video_dir``. Every other key is optional and read with ``.get`` and a
    default at the point of use.
    """

    def __init__(self, params):
        # Validate before any Lightning state is created so a bad config
        # fails immediately.
        for data_key in ("train_path", "dev_path", "audio_dir", "video_dir"):
            if data_key not in params:
                raise KeyError(
                    f"{data_key} is a required hparam in this model"
                )

        super().__init__()
        self.model_params = params

        self.embedding_dim = self.model_params.get("embedding_dim", 300)
        # NOTE(review): the sum of these two widths sizes the classifier's
        # input, so they presumably must match the channel counts produced by
        # the resnet18 backbones -- confirm the 300 defaults against .resnet.
        self.audio_feature_dim = self.model_params.get("audio_feature_dim", 300)
        self.video_feature_dim = self.model_params.get("video_feature_dim", 300)
        self.output_path = Path(self.model_params.get("output_path", "model-outputs"))
        # parents=True so nested output paths (e.g. "runs/exp1") also work.
        self.output_path.mkdir(parents=True, exist_ok=True)

        self.max_epochs = self.model_params.get("max_epochs", 100)
        self.audio_transform = self._build_audio_transform()
        self.video_transform = self._build_video_transform()
        self.train_dataset = self._build_dataset("train")
        self.dev_dataset = self._build_dataset("val")

        # Optimizer and scheduler stepping is done by hand in training_step.
        self.automatic_optimization = False
        # set up model and training
        self.model = self._build_model()
        self.trainer_params = self._get_trainer_params()

        # Running loss sums and batch counts for the per-epoch averages.
        self.train_epoch_loss = 0.0
        self.val_epoch_loss = 0.0
        self._train_batches_seen = 0
        self._val_batches_seen = 0
        self.train_losses_per_epoch = []
        self.val_losses_per_epoch = []

        self.use_augmentation = self.model_params.get("use_augmentation")

    @property
    def hparams(self):
        """Return the raw ``params`` mapping passed to the constructor."""
        return self.model_params

    def forward(self, audio, video, label=None, use_augmentation=False):
        """Delegate to the wrapped model; returns ``(logits, loss)``."""
        return self.model(audio, video, label, use_augmentation)

    def on_train_epoch_start(self):
        """Reset the running training-loss accumulator."""
        self.train_epoch_loss = 0.0
        self._train_batches_seen = 0

    def training_step(self, batch, batch_nb):
        """Run one manually-optimized training step.

        Gradients are accumulated over ``accumulate_grad_batches`` batches
        (hparam, default 1). The optimizer also steps on the last batch of
        the epoch so that a trailing partial accumulation window is applied
        instead of leaking into the next epoch. The LR scheduler steps once
        per epoch, on the last batch.

        Args:
            batch: Mapping with ``"audio"``, ``"video"`` and ``"label"``.
            batch_nb: Index of the batch within the epoch.

        Returns:
            ``{"train_loss": loss}`` with the unscaled batch loss.
        """
        self.train()
        accumulate_grad_batches = self.hparams.get("accumulate_grad_batches", 1)
        optimizer = self.optimizers()
        sch = self.lr_schedulers()

        _, loss = self.forward(
            audio=batch["audio"],
            video=batch["video"],
            label=batch["label"],
            use_augmentation=False,
        )

        # Only the back-propagated tensor is scaled; the reported loss stays
        # unscaled so it is comparable with val_loss.
        self.manual_backward(loss / accumulate_grad_batches)

        is_last_batch = self.trainer.is_last_batch
        if (batch_nb + 1) % accumulate_grad_batches == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()

        if is_last_batch:
            sch.step()

        self.log("train_loss", loss, prog_bar=True)

        self.train_epoch_loss = self.train_epoch_loss + loss.item()
        self._train_batches_seen += 1
        return {"train_loss": loss}

    def on_train_epoch_end(self):
        """Record and log the mean training loss of the finished epoch."""
        # Divide by the batches actually processed rather than building a new
        # DataLoader just to read its length.
        avg_train_loss = self.train_epoch_loss / max(self._train_batches_seen, 1)
        self.train_losses_per_epoch.append(avg_train_loss)
        self.train_epoch_loss = 0.0
        self._train_batches_seen = 0
        self.log("avg_train_loss", avg_train_loss)

    def on_validation_epoch_start(self):
        """Reset the running validation-loss accumulator."""
        self.val_epoch_loss = 0.0
        self._val_batches_seen = 0

    def validation_step(self, batch, batch_nb):
        """Compute and log loss and accuracy for one validation batch.

        Returns:
            ``{"batch_val_loss": loss, "batch_val_acc": acc}``.
        """
        self.eval()
        logits, loss = self.forward(
            audio=batch["audio"],
            video=batch["video"],
            label=batch["label"],
            use_augmentation=False,
        )

        # Get actual predictions by taking argmax
        preds = torch.argmax(logits, dim=1)

        acc = FM.accuracy(
            preds,
            batch["label"],
            task="multiclass",
            num_classes=self.hparams.get("num_classes", 6),
        )
        self.log("val_acc", acc, prog_bar=True)
        self.log("val_loss", loss, prog_bar=True)
        self.val_epoch_loss = self.val_epoch_loss + loss.item()
        self._val_batches_seen += 1
        return {"batch_val_loss": loss, "batch_val_acc": acc}

    def on_validation_epoch_end(self):
        """Record and log the mean validation loss of the finished epoch."""
        avg_val_loss = self.val_epoch_loss / max(self._val_batches_seen, 1)
        self.val_losses_per_epoch.append(avg_val_loss)

        self.val_epoch_loss = 0.0
        self._val_batches_seen = 0
        self.log("avg_val_loss", avg_val_loss)
        return {
            "val_loss": avg_val_loss
        }

    def configure_optimizers(self):
        """Build SGD (momentum 0.9) with a StepLR schedule (x0.1 every 50 steps).

        ``lr`` and ``weight_decay`` come from hparams (defaults 0.001, 1e-4).
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.get("lr", 0.001),
            momentum=0.9,
            weight_decay=self.hparams.get("weight_decay", 1e-4),
        )
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
        return [optimizer], [scheduler]

    # @pl.data_loader
    def train_dataloader(self):
        """Return a shuffled DataLoader over the training dataset."""
        return torch.utils.data.DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=self.hparams.get("batch_size", 4),
            num_workers=self.hparams.get("num_workers", 16),
        )

    # @pl.data_loader
    def val_dataloader(self):
        """Return an unshuffled DataLoader over the validation dataset."""
        return torch.utils.data.DataLoader(
            self.dev_dataset,
            shuffle=False,
            batch_size=self.hparams.get("batch_size", 4),
            num_workers=self.hparams.get("num_workers", 16),
        )

    ## Convenience Methods ##

    def fit(self):
        """Create a Trainer from ``self.trainer_params`` and fit this module."""
        self.trainer = pl.Trainer(**self.trainer_params)
        self.trainer.fit(self)

    def _set_seed(self, seed):
        """Seed Python, NumPy and torch (including all CUDA devices) RNGs."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def test_step(self, batch, batch_nb=None):
        """Compute and log accuracy and macro-F1 for one test batch.

        Args:
            batch: Mapping with ``"audio"``, ``"video"`` and ``"label"``.
            batch_nb: Optional batch index; accepted so the hook matches the
                signature of the other step methods. Unused.

        Returns:
            ``{"test_acc": acc, "test_f1": f1}``.
        """
        self.eval()
        audio, video, labels = batch["audio"], batch["video"], batch["label"]
        logits, _ = self.forward(audio, video)
        preds = torch.argmax(logits, dim=1)

        # Same source for the class count as validation_step.
        num_classes = self.hparams.get("num_classes", 6)
        acc = FM.accuracy(preds, labels, task="multiclass", num_classes=num_classes)
        f1 = FM.f1_score(
            preds, labels, task="multiclass", num_classes=num_classes, average="macro"
        )

        self.log("test_acc", acc, prog_bar=True)
        self.log("test_f1", f1, prog_bar=True)

        return {"test_acc": acc, "test_f1": f1}

    def on_test_epoch_end(self):
        """Re-log the test metrics found in ``trainer.callback_metrics``."""
        # Access test_acc metrics directly through self.trainer.callback_metrics
        test_acc = self.trainer.callback_metrics["test_acc"]
        test_f1 = self.trainer.callback_metrics["test_f1"]

        self.log("test_acc", test_acc)
        self.log("test_f1", test_f1)

        return {"test_acc": test_acc, "test_f1": test_f1}

    def test(self, dataloader):
        """Evaluate on ``dataloader`` without a Trainer and print the metrics.

        Predictions are computed here rather than through ``test_step``:
        that hook logs via ``self.log`` and returns only per-batch metrics,
        so it cannot supply the raw predictions needed for dataset-level
        scores.

        Args:
            dataloader: Iterable of batches with ``"audio"``, ``"video"`` and
                ``"label"`` tensors.

        Returns:
            ``(test_acc, test_f1)`` computed over all batches.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        self.eval()
        num_classes = self.hparams.get("num_classes", 6)
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in dataloader:
                # No Trainer is moving batches here, so place inputs on the
                # module's device explicitly.
                logits, _ = self.forward(
                    batch["audio"].to(self.device),
                    batch["video"].to(self.device),
                )
                all_preds.append(torch.argmax(logits, dim=1).cpu())
                all_labels.append(batch["label"].cpu())

        if not all_preds:
            raise ValueError("dataloader yielded no batches to evaluate")

        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)

        test_acc = FM.accuracy(
            all_preds, all_labels, task="multiclass", num_classes=num_classes
        )
        test_f1 = FM.f1_score(
            all_preds,
            all_labels,
            task="multiclass",
            num_classes=num_classes,
            average="macro",
        )

        print(f"Test Accuracy: {test_acc.item():.4f}")
        print(f"Test F1 Score: {test_f1.item():.4f}")

        return test_acc, test_f1

    def _build_audio_transform(self):
        """Return a MelSpectrogram transform (16 kHz, n_fft 1024, hop 512, 128 mels)."""
        return torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_fft=1024,
            hop_length=512,
            n_mels=128,
        )

    def _build_video_transform(self):
        """Return the frame transform: resize, to-tensor, channel normalization.

        The square target size comes from the ``video_dim`` hparam
        (default 224).
        """
        video_dim = self.hparams.get("video_dim", 224)
        video_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(
                    size=(video_dim, video_dim)
                ),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)
                ),
            ]
        )
        return video_transform

    def _build_dataset(self, mode):
        """Construct a ``CREMADDataset`` for ``mode`` from the path hparams.

        Args:
            mode: Split selector forwarded to the dataset; this class passes
                ``"train"`` and ``"val"``.
        """
        return CREMADDataset(
            mode=mode,
            audio_path=self.hparams.get("audio_dir"),
            visual_path=self.hparams.get("video_dir"),
            train_csv=self.hparams.get("train_path"),
            val_csv=self.hparams.get("dev_path"),
            stat_csv=self.hparams.get("stat_path"),
        )

    def _build_model(self):
        """Build the fusion model from two ``resnet18`` backbones."""
        audio_module = resnet18(modality='audio')
        video_module = resnet18(modality='visual')

        return AudioAndVisionModel(
            num_classes=self.hparams.get("num_classes", 6),
            loss_fn=torch.nn.CrossEntropyLoss(),
            audio_module=audio_module,
            video_module=video_module,
            audio_feature_dim=self.audio_feature_dim,
            video_feature_dim=self.video_feature_dim,
            fusion_output_size=self.hparams.get(
                "fusion_output_size", 512
            ),
            use_augmentation=self.hparams.get("use_augmentation", True),
            train_dataset=self.train_dataset,
            batch_size=self.hparams.get("batch_size", 32),
            device_idx=self.hparams.get("devices", [1]),
        )

    def _get_trainer_params(self):
        """Assemble keyword arguments for ``pl.Trainer``.

        Configures a best-``val_acc`` checkpoint callback writing into
        ``self.output_path`` and an early-stopping callback whose monitor,
        min-delta, patience, verbosity and mode are read from hparams.
        """
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=self.output_path,
            filename='{epoch}-{val_loss:.3f}-{val_acc:.3f}',
            monitor='val_acc',  # Metric to monitor
            mode='max',         # Save model when metric is maximized
            save_top_k=1,       # Save only the best model
            verbose=True,
        )

        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor=self.hparams.get(
                "early_stop_monitor", "val_acc"
            ),
            min_delta=self.hparams.get(
                "early_stop_min_delta", 0.001
            ),
            patience=self.hparams.get(
                "early_stop_patience", 5
            ),
            verbose=self.hparams.get("verbose", True),
            # Defaults to "max" (right for accuracy); set "min" when
            # early_stop_monitor names a loss.
            mode=self.hparams.get("early_stop_mode", "max"),
        )

        trainer_params = {
            "callbacks": [checkpoint_callback, early_stop_callback],

            "devices": self.hparams.get("devices", [1]),
            "max_epochs": self.hparams.get("max_epochs", 100),

            "num_sanity_val_steps": 0,
            "reload_dataloaders_every_n_epochs": 1,
        }
        return trainer_params